#include "ToRORd_fkatp_endo.h"
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
// Solver settings shared by every function in this file. They are copied from
// the ode_solver struct in set_model_initial_conditions_cpu().
real max_step; // upper clamp for the adaptive time step (0 disables the clamp, see solve_forward_euler_cpu_adpt)
real min_step; // lower clamp for the adaptive time step
real abstol;   // absolute tolerance used by the adaptive error estimate
real reltol;   // relative tolerance used by the adaptive error estimate
bool adpt;     // true: adaptive Euler is used; false: fixed-step Euler
// Per-cell adaptive-stepping state, indexed by sv_id. Only allocated when adpt is true.
real *ode_dt, *ode_previous_dt, *ode_time_new;

GET_CELL_MODEL_DATA(init_cell_model_data) {

    // Each field is filled in only when the caller asks for it through the
    // corresponding flag; the two requests are independent of each other.
    if(get_neq) {
        cell_model->number_of_ode_equations = NEQ;
    }

    if(get_initial_v) {
        cell_model->initial_v = INITIAL_V;
    }
}

SET_ODE_INITIAL_CONDITIONS_CPU(set_model_initial_conditions_cpu) {

    log_info("Using ToRORd_fkatp_endo CPU model\n");

    uint32_t num_cells = solver->original_num_cells;
	solver->sv = (real*)malloc(NEQ*num_cells*sizeof(real));

    max_step = solver->max_dt;
    min_step = solver->min_dt;
    abstol   = solver->abs_tol;
    reltol   = solver->rel_tol;
    adpt     = solver->adaptive;

    if(adpt) {
        ode_dt = (real*)malloc(num_cells*sizeof(real));

        OMP(parallel for)
        for(int i = 0; i < num_cells; i++) {
            ode_dt[i] = solver->min_dt;
        }

        ode_previous_dt = (real*)calloc(num_cells, sizeof(real));
        ode_time_new    = (real*)calloc(num_cells, sizeof(real));
        log_info("Using Adaptive Euler model to solve the ODEs\n");
    } else {
        log_info("Using Euler model to solve the ODEs\n");
    }

    real *ischFactor;
    real ischFactor_size = num_cells*sizeof(real);

    real *APEXB;
    real APEXB_size = num_cells*sizeof(real);
	
	real *HCMRE;
	real HCMRE_size = num_cells*sizeof(real);
	
	int *CELLT;
	real CELLT_size = num_cells*sizeof(int);
	
	real *ctrlNormEndo;
	real *ctrlIschEndo;
	real *hcmNormEndo;
	real *hcmIschEndo;
	real *ctrlNormEpi;
	real *ctrlIschEpi;
	real *hcmNormEpi;
	real *hcmIschEpi;


    struct extra_data_for_HCM* extra_data_from_solver = (struct extra_data_for_HCM*)solver->ode_extra_data;
    bool deallocate = false;
	
    if(solver->ode_extra_data) {
		ischFactor = extra_data_from_solver->ISCH;
		HCMRE = extra_data_from_solver->HCMRE;
		CELLT = extra_data_from_solver->CELLT;
		APEXB = extra_data_from_solver->APEXB;
		
		ctrlNormEndo = extra_data_from_solver->ctrlNormEndo;
		ctrlIschEndo = extra_data_from_solver->ctrlIschEndo;
		hcmNormEndo = extra_data_from_solver->hcmNormEndo;
		hcmIschEndo = extra_data_from_solver->hcmIschEndo;
		ctrlNormEpi = extra_data_from_solver->ctrlNormEpi;
		ctrlIschEpi = extra_data_from_solver->ctrlIschEpi;
		hcmNormEpi = extra_data_from_solver->hcmNormEpi;
		hcmIschEpi = extra_data_from_solver->hcmIschEpi;

    }
    else { // Default values for healthy cell
        ischFactor = (real*) malloc(ischFactor_size);
	APEXB = (real*) malloc(APEXB_size);
		HCMRE = (real*) malloc(HCMRE_size);
		CELLT = (int*) malloc(CELLT_size);

        for(uint64_t i = 0; i < num_cells; i++) {
			ischFactor[i] = 1.0; // Default to non-ischemic
			HCMRE[i] = 1.0; // Default to non-HCM
			CELLT[i] = 1.0; // Default to endo
			APEXB[i] = 1.0; // Default to no apex base gradient
        }
	


        deallocate = true;
    }

    OMP(parallel for)
    for(uint32_t i = 0; i < num_cells; i++) {

        real *sv = &solver->sv[i * NEQ];
		
		// ENDO
		if(ischFactor[i]<0.5 && HCMRE[i]>=1.5 && CELLT[i]==1){ // HCM ISCH
			for(int i = 0; i < 43; i++){sv[i] = hcmIschEndo[i];}
		}
		else if(ischFactor[i]<0.5 && HCMRE[i]<1.5 && CELLT[i]==1){ // Control ISCH
			for(int i = 0; i < 43; i++){sv[i] = ctrlIschEndo[i];}
		}
		else if(ischFactor[i]>=0.5 && HCMRE[i]>=1.5 && CELLT[i]==1){ // HCM Normal
			for(int i = 0; i < 43; i++){sv[i] = hcmNormEndo[i];}
		}
		else if(ischFactor[i]>=0.5 && HCMRE[i]<1.5 && CELLT[i]==1){ // Control Normal
			for(int i = 0; i < 43; i++){sv[i] = ctrlNormEndo[i];}
		}
		
		/// EPI
		else if(ischFactor[i]<0.5 && HCMRE[i]>=1.5 && CELLT[i]==3){ // HCM ISCH
			for(int i = 0; i < 43; i++){sv[i] = hcmIschEpi[i];}
		}
		else if(ischFactor[i]<0.5 && HCMRE[i]<1.5 && CELLT[i]==3){ // Control ISCH
			for(int i = 0; i < 43; i++){sv[i] = ctrlIschEpi[i];}
		}
		else if(ischFactor[i]>=0.5 && HCMRE[i]>=1.5 && CELLT[i]==3){ // HCM Normal
			for(int i = 0; i < 43; i++){sv[i] = hcmNormEpi[i];}
		}
		else if(ischFactor[i]>=0.5 && HCMRE[i]<1.5 && CELLT[i]==3){ // Control Normal
			for(int i = 0; i < 43; i++){sv[i] = ctrlNormEpi[i];}
		}
		
		
		
	}
	
		if(deallocate) free(ischFactor);
		if(deallocate) free(HCMRE);
		if(deallocate) free(CELLT);
}


SOLVE_MODEL_ODES(solve_model_odes_cpu) {

    uint32_t sv_id;

	real *APEXB;
    real *ischFactor;
	real *HCMRE;
	int *CELLT;

    size_t num_cells_to_solve = ode_solver->num_cells_to_solve;
    uint32_t * cells_to_solve = ode_solver->cells_to_solve;
    real *sv = ode_solver->sv;
    real dt = ode_solver->min_dt;
    uint32_t num_steps = ode_solver->num_steps;

    int num_extra_parameters = 15;
    real extra_par[num_extra_parameters];
	
    real ischFactor_size = num_cells_to_solve*sizeof(real);
	real HCMRE_size = num_cells_to_solve*sizeof(real);
	real CELLT_size = num_cells_to_solve*sizeof(int);
	real APEXB_size = num_cells_to_solve*sizeof(real);

    struct extra_data_for_HCM* extra_data_from_solver = (struct extra_data_for_HCM*)ode_solver->ode_extra_data;
    bool deallocate = false;
	
    if(ode_solver->ode_extra_data) {
        ischFactor = extra_data_from_solver->ISCH;
		HCMRE = extra_data_from_solver->HCMRE;
		CELLT = extra_data_from_solver->CELLT;
		APEXB = extra_data_from_solver->APEXB;
        extra_par[0] = extra_data_from_solver->INaFactor;
        extra_par[1] = extra_data_from_solver->ICaLFactor;
        extra_par[2] = extra_data_from_solver->Ko;
        extra_par[3] = extra_data_from_solver->f;
		
		extra_par[4] = extra_data_from_solver->mCaL;
		extra_par[5] = extra_data_from_solver->mNa;
		extra_par[6] = extra_data_from_solver->mto;
		extra_par[7] = extra_data_from_solver->mNaL;
		extra_par[8] = extra_data_from_solver->mKr;
		extra_par[9] = extra_data_from_solver->mKs;
		extra_par[10] = extra_data_from_solver->mK1;
		extra_par[11] = extra_data_from_solver->mNaCa;
		extra_par[12] = extra_data_from_solver->mNaK;
		extra_par[13] = extra_data_from_solver->mRel;
		extra_par[14] = extra_data_from_solver->mUp;
    }
    else {
        // Default values for a healthy cell ///////////
		extra_par[0] = 1.0f;
        extra_par[1] = 1.0f;
        extra_par[2] = 5.0f;
        extra_par[3] = 0.0f;
		
		extra_par[4] = 1.0f;
		extra_par[5] = 1.0f;
		extra_par[6] = 1.0f;
		extra_par[7] = 1.0f;
		extra_par[8] = 1.0f;
		extra_par[9] = 1.0f;
		extra_par[10] = 1.0f;
		extra_par[11] = 1.0f;
		extra_par[12] = 1.0f;
		extra_par[13] = 1.0f;
		extra_par[14] = 1.0f;
		
		ischFactor = (real*) malloc(ischFactor_size);
		HCMRE = (real*) malloc(HCMRE_size);
		CELLT = (int*) malloc(CELLT_size);
		APEXB = (real*) malloc(APEXB_size);

        for(uint64_t i = 0; i < num_cells_to_solve; i++) {
			ischFactor[i] = 1.0; // Default to non-ischemic
			HCMRE[i] = 1.0; // Default to non-HCM
			CELLT[i] = 1.0; // Default to endo
			APEXB[i] = 1.0;
        }
        deallocate = true;
    }

    #pragma omp parallel for private(sv_id)
    for (u_int32_t i = 0; i < num_cells_to_solve; i++) {
			if(cells_to_solve)
				sv_id = cells_to_solve[i];
			else
				sv_id = i;

			if(adpt) {

				solve_forward_euler_cpu_adpt(sv + (sv_id * NEQ), stim_currents[i], current_t + dt, sv_id, ischFactor[i], extra_par, HCMRE[i], CELLT[i], APEXB[i]);
			}
			else {
				for (int j = 0; j < num_steps; ++j) {
					solve_model_ode_cpu(dt, sv + (sv_id * NEQ), stim_currents[i], ischFactor[i], extra_par, HCMRE[i], CELLT[i], APEXB[i]);
				}

			}
		}
    
    if(deallocate) free(ischFactor);
	if(deallocate) free(HCMRE);
	if(deallocate) free(CELLT);
}

/*
 * Performs one explicit (forward) Euler step of size dt on the state vector sv
 * of a single cell, in place: sv <- sv + dt * f(sv).
 *
 * The derivative f is evaluated by RHS_cpu() on a snapshot of the state, and
 * the remaining arguments are forwarded to RHS_cpu() unchanged.
 */
void solve_model_ode_cpu(real dt, real *sv, real stim_current, real ischFactor, real *extra_parameters, real HCMRE, int CELLT, real APEXB)  {

    real state_snapshot[NEQ];
    real derivative[NEQ];
    int eq;

    // Snapshot the current state so the update below reads consistent values.
    for(eq = 0; eq < NEQ; ++eq) {
        state_snapshot[eq] = sv[eq];
    }

    RHS_cpu(state_snapshot, derivative, stim_current, dt, ischFactor, extra_parameters, HCMRE, CELLT, APEXB);

    eq = 0;
    while(eq < NEQ) {
        sv[eq] = dt*derivative[eq] + state_snapshot[eq];
        ++eq;
    }
}

/*
 * Integrates one cell from its stored time (ode_time_new[sv_id]) up to
 * final_time with an adaptive-step Euler scheme.
 *
 * Each trial step takes an Euler step with slope k1, re-evaluates the slope at
 * the new state (k2) and uses |dt/2 * (k1 - k2)| scaled by
 * max(abstol, reltol * |new state|) as the error estimate. The step size is
 * then rescaled by beta_safety * sqrt(1 / error). A trial step whose error is
 * >= 1 is rejected and retried, up to count_limit rejections; after that the
 * step is accepted anyway and a message is printed.
 *
 * Parameters:
 *   sv          state vector of the cell, updated in place
 *   stim_curr   stimulus current forwarded to RHS_cpu()
 *   final_time  time to integrate up to
 *   sv_id       index into ode_dt / ode_previous_dt / ode_time_new
 *   remaining   forwarded unchanged to RHS_cpu()
 *
 * Side effects: updates ode_dt[sv_id], ode_previous_dt[sv_id] and
 * ode_time_new[sv_id]; reads the globals abstol, reltol, min_step, max_step.
 *
 * The two slope buffers live on the stack: the previous per-call malloc/free
 * pair was unchecked and ran once per cell per time step.
 */
void solve_forward_euler_cpu_adpt(real *sv, real stim_curr, real final_time, int sv_id, real ischFactor, real *extra_parameters, real HCMRE, int CELLT, real APEXB) {

    const real beta_safety = 0.8;

    real rDY[NEQ];
    real state_before_step[NEQ];
    real euler_state[NEQ];

    // k1/k2 are swapped by pointer after every accepted step.
    real slope_storage_a[NEQ];
    real slope_storage_b[NEQ];
    real *k1 = slope_storage_a;
    real *k2 = slope_storage_b;
    real *k_swap;

    real tolerance;
    real rel_tolerance = 0.0;

    ode_previous_dt[sv_id] = ode_dt[sv_id];

    real *dt = &ode_dt[sv_id];
    real *time_new = &ode_time_new[sv_id];
    real *previous_dt = &ode_previous_dt[sv_id];

    if(*time_new + *dt > final_time) {
       *dt = final_time - *time_new;
    }

    RHS_cpu(sv, rDY, stim_curr, *dt, ischFactor, extra_parameters, HCMRE, CELLT, APEXB);
    *time_new += *dt;

    for(int i = 0; i < NEQ; i++) {
        k1[i] = rDY[i];
    }

    // Added to the error so the step-size formula never divides by zero.
    const double tiny = pow(abstol, 2.0);

    int count = 0;
    int count_limit = (final_time - *time_new)/min_step;
    int aux_count_limit = count_limit+2000000;
    if(aux_count_limit > 0) {
        count_limit = aux_count_limit;
    }

    while(1) {
        for(int i = 0; i < NEQ; i++) {
            state_before_step[i] = sv[i];
            euler_state[i] = k1[i] * *dt + state_before_step[i];
            sv[i] = euler_state[i];
        }

        *time_new += *dt;
        RHS_cpu(sv, rDY, stim_curr, *dt, ischFactor, extra_parameters, HCMRE, CELLT, APEXB);
        *time_new -= *dt;

        double greatestError = 0.0, auxError = 0.0;
        for(int i = 0; i < NEQ; i++) {
            k2[i] = rDY[i];
            rel_tolerance = fabs(euler_state[i])*reltol;
            tolerance = (abstol > rel_tolerance) ? abstol : rel_tolerance;
            auxError = fabs(( (*dt/2.0)*(k1[i] - k2[i])) / tolerance);

            greatestError = (auxError > greatestError) ? auxError : greatestError;
        }

        greatestError += tiny;
        *previous_dt = *dt;
        // adapt the time step
        *dt = beta_safety * (*dt) * sqrt(1.0f/greatestError);

        if (*time_new + *dt > final_time) {
            *dt = final_time - *time_new;
        }

        if ( count < count_limit  && (greatestError >= 1.0f)) {
            // Rejected: restore the state and retry with the reduced step.
            for(int i = 0;  i < NEQ; i++) {
                sv[i] = state_before_step[i];
            }

            count++;
        } else {
            // Accepted (possibly forced, once the retry budget is used up).
            if(greatestError >=1.0) {
                printf("Accepting solution with error > %lf \n", greatestError);
            }

            if (*dt < min_step) {
                *dt = min_step;
            }

            else if (*dt > max_step && max_step != 0) {
                *dt = max_step;
            }

            if (*time_new + *dt > final_time) {
                *dt = final_time - *time_new;
            }

            // The slope at the new state becomes k1 of the next step.
            k_swap = k2;
            k2 = k1;
            k1 = k_swap;

            // it steps the method ahead, with euler solution
            for(int i = 0; i < NEQ; i++){
                sv[i] = euler_state[i];
            }

            if(*time_new + *previous_dt >= final_time){
                if((fabs(final_time - *time_new) < 1.0e-5) ){
                    break;
                }else if(*time_new < final_time){
                    *dt = *previous_dt = final_time - *time_new;
                    *time_new += *previous_dt;
                    break;

                }else{
                    printf("Error: time_new %.20lf final_time %.20lf diff %e \n", *time_new , final_time, fabs(final_time - *time_new) );
                    break;
                }
            }else{
                *time_new += *previous_dt;
            }

        }
    }
}

/*
 * Right-hand side of the cell model for a single cell.
 *
 * The visible part of this function only unpacks the 43 entries of sv into
 * named const locals. The actual model equations come from the included file
 * "ToROrd_common.inc.c", which is expanded textually inside this function body
 * and therefore depends on the exact local names declared below — do not
 * rename them.
 *
 * NOTE(review): the include is not visible here; presumably it writes the
 * derivatives into rDY_ and consumes stim_current, dt, ischFactor,
 * extra_parameters, HCMRE, CELLT and APEXB — confirm against that file.
 */
void RHS_cpu(const real *sv, real *rDY_, real stim_current, real dt, real ischFactor, real *extra_parameters, real HCMRE, int CELLT, real APEXB) {

    //State variables
    const real v_old_ = sv[0];
    const real CaMKt_old_ = sv[1];
    const real nai_old_ = sv[2];
    const real nass_old_ = sv[3];
    const real ki_old_ = sv[4];
    const real kss_old_ = sv[5];
    const real cai_old_ = sv[6];
    const real cass_old_ = sv[7];
    const real cansr_old_ = sv[8];
    const real cajsr_old_ = sv[9];
    const real m_old_ = sv[10];
    const real h_old_ = sv[11];
    const real j_old_ = sv[12];
    const real hp_old_ = sv[13];
    const real jp_old_ = sv[14];
    const real mL_old_ = sv[15];
    const real hL_old_ = sv[16];
    const real hLp_old_ = sv[17];
    const real a_old_ = sv[18];
    const real iF_old_ = sv[19];
    const real iS_old_ = sv[20];
    const real ap_old_ = sv[21];
    const real iFp_old_ = sv[22];
    const real iSp_old_ = sv[23];
    const real d_old_ = sv[24];
    const real ff_old_ = sv[25];
    const real fs_old_ = sv[26];
    const real fcaf_old_ = sv[27];
    const real fcas_old_ = sv[28];
    const real jca_old_ = sv[29];
    const real ffp_old_ = sv[30];
    const real fcafp_old_ = sv[31];
    const real nca_ss_old_ = sv[32];
    const real nca_i_old_ = sv[33];
    const real C3_old_ = sv[34];
    const real C2_old_ = sv[35];
    const real C1_old_ = sv[36];
    const real O_old_ = sv[37];
    const real I_old_ = sv[38];
    const real xs1_old_ = sv[39];
    const real xs2_old_ = sv[40];
    const real Jrel_np_old_ = sv[41];
    const real Jrel_p_old_ = sv[42];
	
	// The model equations are pulled in textually and use the locals above.

    #include "ToROrd_common.inc.c"
}
